import java.util.*;

public class Solution2316 {
    /**
     * Union-find (disjoint set) over arbitrary elements, with path compression
     * and per-root component sizes.
     *
     * @param <T> element type; must have consistent equals/hashCode
     */
    class DisjointSet<T> {
        /** Number of disjoint components currently tracked. */
        int count = 0;

        /** Component size, keyed by each component's current representative (root) only. */
        Map<T, Long> jointSizeMapper = new HashMap<>();
        /** Parent pointers; a root maps to itself. */
        Map<T, T> preMapper = new HashMap<>();

        /**
         * Returns the representative (root) of the component containing {@code item},
         * compressing the path so later lookups are faster.
         *
         * <p>Implemented iteratively: unions do not balance trees, so parent chains can be
         * as long as the number of elements, and a recursive walk could overflow the stack.
         *
         * @param item element to look up
         * @return the component root, or {@code null} if {@code item} was never added
         */
        public T getRepresent(T item) {
            T root = preMapper.get(item);
            if (root == null) {
                return null;
            }
            // Pass 1: climb parent pointers until reaching a self-mapped root.
            T parent = preMapper.get(root);
            while (!Objects.equals(parent, root)) {
                root = parent;
                parent = preMapper.get(root);
            }
            // Pass 2: point every node on the walked path directly at the root.
            T cur = item;
            while (!Objects.equals(cur, root)) {
                T next = preMapper.get(cur);
                preMapper.put(cur, root);
                cur = next;
            }
            return root;
        }

        /**
         * Registers {@code item} as a new singleton component.
         *
         * @param item element to add
         */
        public void addJoint(T item) {
            preMapper.put(item, item);
            jointSizeMapper.put(item, 1L);
            count++;
        }

        /**
         * Merges the components containing {@code a} and {@code b}. The root of
         * {@code b}'s component is attached under the root of {@code a}'s component.
         *
         * @param a element of the surviving component
         * @param b element whose component is merged into {@code a}'s
         * @throws IllegalArgumentException if either element was never added; rejecting
         *     them keeps null keys out of the internal maps
         */
        public void addJoint(T a, T b) {
            T aRepresent = getRepresent(a);
            T bRepresent = getRepresent(b);
            if (aRepresent == null || bRepresent == null) {
                throw new IllegalArgumentException(
                        "unknown element: " + (aRepresent == null ? a : b));
            }
            if (!Objects.equals(aRepresent, bRepresent)) {
                preMapper.put(bRepresent, aRepresent);
                jointSizeMapper.put(aRepresent, jointSizeMapper.get(aRepresent) + jointSizeMapper.get(bRepresent));
                jointSizeMapper.remove(bRepresent);
                count--;
            }
        }
    }

    /**
     * Counts unordered pairs of nodes that are not connected by any path.
     *
     * <p>A component of size s has s * (n - s) ordered pairs that leave it. Summing this
     * over all components counts each unreachable pair twice, hence the final halving.
     *
     * @param n     number of nodes, labelled 0..n-1
     * @param edges undirected edges; each entry is {@code {u, v}} with 0 <= u, v < n
     * @return number of unreachable unordered node pairs
     */
    public long countPairs(int n, int[][] edges) {
        DisjointSet<Integer> disjointSet = new DisjointSet<>();
        for (int i = 0; i < n; i++) {
            disjointSet.addJoint(i);
        }
        for (int[] edge : edges) {
            disjointSet.addJoint(edge[0], edge[1]);
        }

        // Primitive stream avoids boxing every intermediate product.
        long crossing = disjointSet.jointSizeMapper.values().stream()
                .mapToLong(Long::longValue)
                .map(size -> size * (n - size))
                .sum();
        return crossing / 2;
    }


    /** Small demo: prints the answer for a sample graph (expected 14). */
    public static void main(String[] args) {
        Solution2316 s = new Solution2316();
        System.out.println(s.countPairs(7, new int[][]{{0,2},{0,5},{2,4},{1,6},{5,4}}));
    }
}
